import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import gen_label
import copy
from scipy.signal import find_peaks
import matplotlib.pyplot as plt

def CLAS2(logits, label, seq_len, criterion, sf, X_v, epoch, cfg):
    """Compute the video-level classification loss with extra pseudo-normal targets.

    One score is produced for every video ``i``:

    * ``label[i] == 0``: the largest logit among the first ``seq_len[i]`` frames.
    * otherwise: the logit at the earliest valid annotated frame in ``sf[i]``
      (entries equal to -1 or outside ``[0, T)`` are ignored).

    For an abnormal video whose earliest annotation ``first`` gives
    ``end = first - cfg.gap >= 0`` (and ``epoch >= 0``), one extra score with
    target ``0.`` is appended right after it. That score takes the top
    ``end // cfg.pick + 1`` logits in frames ``[0, end]``, weights them by
    ``split_into_parts(end + 1, sigma=cfg.sigma, duration=cfg.duration)``, and
    sums them.

    Args:
        logits: per-frame scores; assumed ``(B, T)`` or ``(B, T, 1)`` — TODO
            confirm. Singleton dims after the batch dim are squeezed; the
            batch dim itself is always kept.
        label: tensor of per-video labels (0 marks a normal video).
        seq_len: valid length of each video.
        criterion: called as ``criterion(scores, targets)``.
        sf: per-video frame annotations, padded with -1.
        X_v: currently unused; kept so existing call sites keep working.
        epoch: current epoch; extra scores are added when ``epoch >= 0``.
        cfg: object providing ``gap``, ``pick``, ``sigma`` and ``duration``.

    Returns:
        Whatever ``criterion`` returns.

    Raises:
        IndexError: if an abnormal video has no valid annotation in ``sf``.
    """
    # Squeeze singleton dims from the last one down to dim 1, never dim 0.
    # A plain .squeeze() turned a one-video batch (1, T, 1) into (T,), so the
    # loop below iterated over frames instead of videos.
    for dim in reversed(range(1, logits.dim())):
        if logits.shape[dim] == 1:
            logits = logits.squeeze(dim)
    device = logits.device
    labels = label.tolist()

    # The float32 seed reproduces the original accumulator's dtype promotion
    # and keeps torch.cat valid for an empty batch.
    scores = [torch.zeros(0, device=device)]
    targets = []
    for i in range(logits.shape[0]):
        video = logits[i]
        targets.append(labels[i])
        if labels[i] == 0:
            top, _ = torch.topk(video[:seq_len[i]], k=1, largest=True)
            scores.append(torch.mean(top).view(1))
            continue

        sf_anno = torch.unique(sf[i][sf[i] != -1])  # unique() returns sorted values
        sf_anno = sf_anno[(sf_anno >= 0) & (sf_anno < video.shape[0])]
        if sf_anno.numel() == 0:
            raise IndexError(
                f"CLAS2: video {i} is labeled abnormal but has no valid frame annotation in sf"
            )
        # Earliest annotated frame: index 0 is the minimum after unique().
        first = int(sf_anno.to(dtype=torch.long)[0])
        scores.append(torch.mean(video[first]).view(1))

        if epoch >= 0:
            end = first - cfg.gap
            if end >= 0:
                # Frames [0, end] lie at least cfg.gap before the first
                # annotated frame. Their weighted top-k logits get target 0.,
                # the same label value the normal branch above uses.
                _, nor_idxs = torch.topk(
                    video[:end + 1], k=end // cfg.pick + 1, largest=True
                )
                nor_idxs = torch.sort(nor_idxs)[0]
                weights = torch.tensor(
                    split_into_parts(end + 1, sigma=cfg.sigma, duration=cfg.duration),
                    dtype=torch.float32,
                    device=device,
                )
                nor_score = torch.sum(video[nor_idxs] * weights[nor_idxs]).view(1)
                scores.append(nor_score)
                targets.append(0.)

    ins_logits = torch.cat(scores)
    # Match the score dtype so integer labels also work with float-target
    # criteria (the inferred dtype was long when every label was an int).
    change_label = torch.tensor(targets, dtype=ins_logits.dtype, device=device)
    return criterion(ins_logits, change_label)

# tmp, _ = torch.topk(logits[i][:seq_len[i]], k=int(seq_len[i]//16+1), largest=True)
def cal_sim(x, key_frame):
    """Return the cosine similarity between row ``key_frame`` of ``x`` and every row of ``x``.

    A 1-D anchor row is lifted to shape ``(1, D)`` so it broadcasts against
    ``x`` in ``F.cosine_similarity`` (default ``dim=1``).
    """
    anchor = x[key_frame]
    anchor = anchor.unsqueeze(0) if anchor.dim() == 1 else anchor
    return F.cosine_similarity(anchor, x)

# NOTE(review): module-level counter that nothing in this file reads or
# writes; it may be used by importers of this module — confirm before removing.
count = 0
def get_abn_range(sims, key_frame):
    """Estimate the frame index where the segment starting at ``key_frame`` ends.

    Local minima of ``sims[key_frame:]`` (peaks of ``-sims``) are located.

    * With no minimum, ``key_frame`` itself is returned.
    * Otherwise let ``drop`` be ``sims[key_frame]`` minus the value at the
      first minimum. If ``drop >= 0.05`` that first minimum is returned.
    * Otherwise the first minimum whose prominence is at least ``drop`` is
      returned, falling back to the first minimum when none qualifies.

    Args:
        sims: 1-D similarity tensor; may require grad.
        key_frame: start index (int or 0-dim integer tensor).

    Returns:
        The chosen absolute frame index as a Python ``int``.
    """
    start = int(key_frame)
    # detach(): .numpy()/np.array() raise on tensors that require grad, which
    # is the case when sims is computed from trainable features.
    neg = -sims[start:].detach().cpu().numpy()
    minima, _ = find_peaks(neg)
    if len(minima) == 0:
        return start
    first_min = int(minima[0])
    drop = float(sims[start] - sims[start + first_min])
    if drop < 0.05:
        # Shallow first dip: prefer the first minimum at least as prominent.
        prominent, _ = find_peaks(neg, prominence=drop)
        if len(prominent) > 0:
            return start + int(prominent[0])
    return start + first_min


def split_into_parts(n, mu=0, sigma=0.5, duration=0.6):
    """Return ``n`` weights that sum to 1, following a Gaussian over ``[0, duration]``.

    The Gaussian density with mean ``mu`` and std ``sigma`` is evaluated at
    ``n`` evenly spaced points from 0 to ``duration`` and normalized by its
    sum. With the default ``mu=0`` the weights decrease along the array.
    """
    grid = np.linspace(0, duration, n)
    z = (grid - mu) / sigma
    density = (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * z ** 2)
    return density / density.sum()


def KLV_loss(preds, label, criterion):
    """Compare log-softmax(preds) against softmax(10 * label) with ``criterion``.

    Args:
        preds: raw scores, normalized over dim 1.
        label: raw target scores, sharpened by a factor of 10 and normalized
            over dim 1.
        criterion: called as ``criterion(log_probs, target_probs)`` —
            presumably ``nn.KLDivLoss``; confirm at call sites.

    Returns:
        The criterion's result, or the integer ``0`` when the log-probabilities
        contain NaN.
    """
    # log_softmax is the numerically stable form of log(softmax(x)). The
    # two-step version underflows to log(0) = -inf for very negative scores;
    # the NaN guard does not catch that, and KLDivLoss then returns inf.
    log_probs = F.log_softmax(preds, dim=1)
    if torch.isnan(log_probs).any():
        return 0
    target = F.softmax(label * 10, dim=1)
    return criterion(log_probs, target)


def temporal_smooth(arr):
    """Return the sum of squared differences between consecutive entries along dim 0.

    This equals comparing ``arr`` with a copy shifted one step forward (last
    entry repeated, so it contributes 0). Empty and single-element inputs
    return 0.
    """
    # arr[1:] - arr[:-1] is simply empty when len(arr) <= 1. The previous
    # shifted-copy version raised IndexError on an empty slice, which
    # Smooth produces when seq_len == 1.
    return torch.sum((arr[1:] - arr[:-1]) ** 2)


def temporal_sparsity(arr):
    """Return the plain sum of all entries of ``arr`` (an L1-style penalty for non-negative scores)."""
    return arr.sum()


def Smooth(logits, seq_len, lamda=8e-5):
    """Return the batch-mean temporal smoothness penalty scaled by ``lamda``.

    Each video's first ``seq_len[b] - 1`` entries go through
    ``temporal_smooth``. The per-video penalties are averaged over the batch
    and multiplied by ``lamda``.
    """
    # NOTE(review): unlike Smooth_Sparsity this slices to seq_len - 1, so the
    # last valid frame is excluded — confirm that is intentional.
    penalties = [
        temporal_smooth(logits[b][:seq_len[b] - 1]) for b in range(logits.shape[0])
    ]
    mean_penalty = sum(penalties) / len(penalties)
    return mean_penalty * lamda


def Sparsity(logits, seq_len, lamda=8e-5):
    """Return the batch-mean sparsity penalty scaled by ``lamda``.

    Each video's first ``seq_len[b]`` entries go through
    ``temporal_sparsity``. The results are averaged over the batch and
    multiplied by ``lamda``.
    """
    penalties = [
        temporal_sparsity(logits[b][:seq_len[b]]) for b in range(logits.shape[0])
    ]
    mean_penalty = sum(penalties) / len(penalties)
    return mean_penalty * lamda


def Smooth_Sparsity(logits, seq_len, lamda=8e-5):
    """Return (mean smoothness + mean sparsity) over the batch, scaled by ``lamda``.

    Both penalties are computed on each video's first ``seq_len[b]`` entries
    via ``temporal_smooth`` and ``temporal_sparsity``, then averaged
    separately over the batch.
    """
    clips = [logits[b][:seq_len[b]] for b in range(logits.shape[0])]
    smooth_term = sum(temporal_smooth(clip) for clip in clips) / len(clips)
    sparse_term = sum(temporal_sparsity(clip) for clip in clips) / len(clips)
    return (smooth_term + sparse_term) * lamda